import os
import torch.optim as optim
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from argparse import ArgumentParser

from hg_model import MoocTrialNet
from config import HP
from hg_dataset import HandGestureDataset


class HgTrainer:
    """Training harness for MoocTrialNet on the hand-gesture dataset.

    Wires together the model, loss, optimizer, data loaders and TensorBoard
    logging, and supports resuming from a checkpoint file given via ``args.c``.
    """

    def __init__(self, args):
        # TensorBoard logger (writes event files under ./log).
        self.logger = SummaryWriter("./log")

        # Model to train, moved to the configured device.
        self.model = MoocTrialNet().to(HP.device)

        # Classification loss.
        self.loss_fn = nn.CrossEntropyLoss()

        # Optimizer.
        self.opt = optim.Adam(self.model.parameters(), HP.init_lr)

        # Training and evaluation data loaders.
        train_set = HandGestureDataset(HP.metadata_train_path)
        self.train_loader = DataLoader(
            train_set, batch_size=HP.batch_size, shuffle=True, drop_last=True
        )
        dev_set = HandGestureDataset(HP.metadata_eval_path)
        self.dev_loader = DataLoader(
            dev_set, batch_size=HP.batch_size, shuffle=True, drop_last=False
        )

        # Training state.
        self.start_epoch = 0
        self.step = 0
        # Best dev loss so far; start at +inf so any real loss improves it
        # (starting at 0 could never be beaten by a non-negative loss).
        self.best_loss = float('inf')
        if args.c:
            # map_location keeps resuming working across CPU/GPU machines.
            checkpoint = torch.load(args.c, map_location=HP.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.opt.load_state_dict(checkpoint['optimizer_state_dict'])
            self.start_epoch = checkpoint['epoch']
            # Older checkpoints did not record the global step; default to 0.
            self.step = checkpoint.get('step', 0)

    def train(self):
        """Run the full training loop with periodic evaluation and saving."""
        self.model.train()
        print("model training start, total # epochs: {}".format(HP.epochs))
        for epoch in range(self.start_epoch, HP.epochs):
            # len(train_loader) is already the number of batches per epoch;
            # do not divide by batch_size again.
            for batch in tqdm(self.train_loader,
                              desc='Start Epoch: %d, Steps: %d' % (epoch, len(self.train_loader))):
                # Unpack the batch.
                imgs, labels = batch
                # Forward pass — inputs must live on the same device as the
                # model, so move imgs to HP.device alongside the labels.
                pred_labels = self.model(imgs.to(HP.device))

                # Compute loss, backpropagate, update parameters.
                self.opt.zero_grad()
                loss = self.loss_fn(pred_labels, labels.to(HP.device))
                loss.backward()
                self.opt.step()

                # Log the training loss as a plain float rather than a
                # live tensor (avoids holding onto the graph).
                self.logger.add_scalar('Loss/Train', loss.item(), self.step)

                # Periodically evaluate on the dev set and log its loss.
                if not self.step % HP.verbose_step:
                    eval_loss = self.evaluate()
                    self.logger.add_scalar('Loss/Dev', eval_loss, self.step)

                # Periodically persist a checkpoint (gated inside save()).
                self.save(epoch)

                self.step += 1
                self.logger.flush()

    def evaluate(self):
        """Return the mean dev-set loss; restores train mode before returning."""
        self.model.eval()
        sum_loss = 0
        with torch.no_grad():
            for batch in self.dev_loader:
                imgs, labels = batch
                # Same device fix as in train(): inputs go to HP.device.
                pred_labels = self.model(imgs.to(HP.device))
                loss = self.loss_fn(pred_labels, labels.to(HP.device))
                sum_loss += loss.item()
        self.model.train()
        return sum_loss / len(self.dev_loader)

    def save(self, current_epoch):
        """Write a checkpoint every HP.save_step global steps."""
        if not self.step % HP.save_step:
            save_dict = {
                'epoch': current_epoch,
                # Record the global step so a resumed run can continue
                # its logging curves instead of restarting at 0.
                'step': self.step,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.opt.state_dict()
            }
            torch.save(save_dict, os.path.join(HP.model_save_root,
                                               HP.model_path % (current_epoch, self.step)))


if __name__ == '__main__':
    # Command-line interface: an optional checkpoint path to resume from.
    arg_parser = ArgumentParser(description="Model Training")
    arg_parser.add_argument(
        '--c',
        type=str,
        default=None,
        help='train from scratch or resume training'
    )
    cli_args = arg_parser.parse_args()
    # Build the trainer and kick off the training loop.
    trainer = HgTrainer(cli_args)
    trainer.train()
